#!/usr/bin/env python
# -*- coding: utf-8 -*-
""" Title """


from collections import OrderedDict
from collections.abc import Mapping
from itertools import product
import math

import luigi
from rdkit import Chem
from numpy.random import default_rng
import pandas as pd
import torch
from torch import nn

# Largest signed 32-bit integer; used as the exclusive upper bound whenever a
# random seed is drawn (numpy ``integers`` / ``torch.randint`` calls below).
MAX_INT = 2 ** 31 - 1

class RngMixin:
    """Mixin that owns a NumPy ``Generator`` and derives integer seeds from it.

    Call :meth:`set_seed` before using any other member; every seed produced
    afterwards is a deterministic function of that initial seed.
    """

    @property
    def rng(self):
        """The ``numpy.random.Generator`` created by :meth:`set_seed`.

        Raises:
            ValueError: if :meth:`set_seed` has not been called yet.
        """
        if not hasattr(self, '_rng'):
            raise ValueError('please call `set_seed` first.')
        return self._rng

    def set_seed(self, seed):
        """(Re)create the generator from ``seed`` via ``numpy.random.default_rng``."""
        self._rng = default_rng(seed)

    def seed_list(self, n_seed):
        """Return an ndarray of ``n_seed`` integers drawn uniformly from ``[0, MAX_INT)``."""
        return self.rng.integers(
            MAX_INT,
            size=n_seed)

    def gen_seed(self):
        """Return one Python ``int`` drawn uniformly from ``[0, MAX_INT)``."""
        # Keep size=1 so the generator stream is consumed exactly as before,
        # but extract the element explicitly: calling int() on a 1-element
        # ndarray is deprecated since NumPy 1.25.
        return int(self.rng.integers(
            MAX_INT,
            size=1)[0])

    def seeding_iter(self, param_dict, n_iter):
        """Return ``n_iter`` copies of ``param_dict`` with freshly drawn seeds.

        Every value stored under the key ``'seed'`` (at any depth inside nested
        mapping values) is replaced by a new :meth:`gen_seed` draw. Nested
        mappings -- ``dict``, ``OrderedDict``, luigi's ``FrozenOrderedDict`` or
        any other ``collections.abc.Mapping`` -- are copied into plain ``dict``
        objects; all other values are shared with the input, not copied.

        Args:
            param_dict: mapping of parameters, possibly nested.
            n_iter: number of copies to produce.

        Returns:
            list of ``n_iter`` independent ``dict`` copies.
        """
        def _substitute_seed(mapping):
            substituted = {}
            for key, val in mapping.items():
                # A mapping under the key 'seed' is recursed into, not replaced.
                if isinstance(val, Mapping):
                    substituted[key] = _substitute_seed(val)
                elif key == 'seed':
                    substituted[key] = self.gen_seed()
                else:
                    substituted[key] = val
            return substituted

        return [_substitute_seed(param_dict) for _ in range(n_iter)]


class TorchRngMixin:
    """Mixin owning two lazily created, reproducible ``torch.Generator`` objects.

    * ``torch_rng`` lives on ``self.device`` when the host object defines a
      ``device`` attribute (otherwise on CPU) and is seeded with the value
      passed to :meth:`set_torch_seed`.
    * ``torch_rng_cpu`` always lives on CPU and is seeded with one integer
      drawn from ``torch_rng``.

    :meth:`delete_rng` swaps the live generators for their ``get_state()``
    snapshots; the properties restore them from those snapshots on next access.
    """

    def _rng_device(self):
        """Device for ``torch_rng``: ``self.device`` if present, otherwise ``'cpu'``."""
        return getattr(self, 'device', 'cpu')

    def _draw_cpu_seed(self):
        """Draw one integer in ``[0, MAX_INT)`` from ``torch_rng`` as a Python int."""
        return int(torch.randint(
            high=MAX_INT,
            size=(1,),
            generator=self.torch_rng,
            device=self._rng_device()).to('cpu'))

    @property
    def torch_rng(self):
        """Generator on the host's device, created on first access.

        Restored from a snapshot saved by :meth:`delete_rng` if one exists,
        otherwise seeded with the seed stored by :meth:`set_torch_seed`.

        Raises:
            ValueError: if neither a snapshot nor a seed is available.
        """
        if not hasattr(self, '_torch_rng'):
            generator = torch.Generator(device=self._rng_device())
            if hasattr(self, '_torch_rng_state'):
                generator.set_state(self._torch_rng_state)
                del self._torch_rng_state
            else:
                try:
                    seed = self._seed
                except AttributeError:
                    raise ValueError(
                        'please call `set_torch_seed` first to set seed.') from None
                generator.manual_seed(seed)
            # Publish only a fully initialised generator; assigning before
            # seeding would leave an unseeded generator that later accesses
            # would silently return after a failed first access.
            self._torch_rng = generator
        return self._torch_rng

    @property
    def torch_rng_cpu(self):
        """CPU generator, created on first access.

        Restored from a :meth:`delete_rng` snapshot if one exists, otherwise
        seeded with one draw from :attr:`torch_rng` (which may in turn raise
        ``ValueError`` if no seed was set).
        """
        if not hasattr(self, '_torch_rng_cpu'):
            generator = torch.Generator(device='cpu')
            if hasattr(self, '_torch_rng_cpu_state'):
                generator.set_state(self._torch_rng_cpu_state)
                del self._torch_rng_cpu_state
            else:
                generator.manual_seed(self._draw_cpu_seed())
            self._torch_rng_cpu = generator
        return self._torch_rng_cpu

    def set_torch_seed(self, seed):
        """Seed both generators deterministically from ``seed``.

        If ``torch_rng`` already exists it is reseeded immediately and the CPU
        generator is (re)seeded with the first draw from it. Otherwise the seed
        is stored and applied when ``torch_rng`` is first accessed. Snapshots
        left by :meth:`delete_rng` are discarded, so they cannot override the
        seed requested here.
        """
        seed = int(seed)
        for state_attr in ('_torch_rng_state', '_torch_rng_cpu_state'):
            if hasattr(self, state_attr):
                delattr(self, state_attr)
        self._seed = seed
        if hasattr(self, '_torch_rng'):
            self._torch_rng.manual_seed(seed)
            # Exactly one draw seeds the CPU generator whether or not it already
            # existed, so the resulting streams do not depend on access history.
            cpu_seed = self._draw_cpu_seed()
            if hasattr(self, '_torch_rng_cpu'):
                self._torch_rng_cpu.manual_seed(cpu_seed)
            else:
                cpu_generator = torch.Generator(device='cpu')
                cpu_generator.manual_seed(cpu_seed)
                self._torch_rng_cpu = cpu_generator
        elif hasattr(self, '_torch_rng_cpu'):
            # This CPU generator was not derived from the new seed; drop it so
            # it is re-derived from ``torch_rng`` on next access.
            del self._torch_rng_cpu

    def gen_seed(self):
        """Return one Python ``int`` in ``[0, MAX_INT)`` drawn from ``torch_rng_cpu``."""
        return int(torch.randint(high=MAX_INT, size=(1,), generator=self.torch_rng_cpu))

    def delete_rng(self):
        """Replace live generators with their ``get_state()`` snapshots.

        The properties restore the generators from these snapshots on next
        access. NOTE(review): presumably used before serialising the object,
        since generator objects are dropped -- confirm with callers.
        """
        if hasattr(self, '_torch_rng'):
            self._torch_rng_state = self._torch_rng.get_state()
            del self._torch_rng
        if hasattr(self, '_torch_rng_cpu'):
            self._torch_rng_cpu_state = self._torch_rng_cpu.get_state()
            del self._torch_rng_cpu



class OptimizerMixin:
    """Mixin that attaches a ``torch.optim`` optimizer to its host object.

    The host class must provide ``parameters()`` (as ``torch.nn.Module`` does).
    """

    def init_optimizer(self,
                       optimizer='Adagrad',
                       optimizer_kwargs=None):
        """Create ``self.optimizer`` over ``self.parameters()``.

        Args:
            optimizer: name of an optimizer class in ``torch.optim``
                (e.g. ``'Adagrad'``, ``'Adam'``, ``'SGD'``).
            optimizer_kwargs: keyword arguments forwarded to the optimizer
                constructor; ``None`` (the default) means ``{'lr': 1e-2}``.

        Raises:
            AttributeError: if ``torch.optim`` has no attribute ``optimizer``.
        """
        # Build the default per call rather than sharing one mutable dict
        # object across every call and every instance.
        if optimizer_kwargs is None:
            optimizer_kwargs = {'lr': 1e-2}
        self.optimizer = getattr(torch.optim, optimizer)(
            params=self.parameters(),
            **optimizer_kwargs)


def construct_dataset(file_path=None, mol_col=None, tgt_col=None, preprocessing_list=None):
    """Load molecules and targets from a CSV file.

    Args:
        file_path: path of a CSV file readable by ``pandas.read_csv``. When
            falsy, no file is read and two empty lists are returned.
        mol_col: name of the column holding SMILES strings.
        tgt_col: name of the column holding target values.
        preprocessing_list: sequence of Python expression strings. Each is
            ``eval``-ed once (in this module's namespace) into a callable that
            is applied element-wise to the targets, in order. ``None`` (the
            default) means no preprocessing.

    Returns:
        ``(mol_list, tgt_list)``: the result of ``Chem.MolFromSmiles`` for each
        SMILES string (entries are not filtered -- RDKit presumably yields
        ``None`` for unparsable SMILES) and the preprocessed target values.
    """
    if preprocessing_list is None:
        preprocessing_list = []
    if file_path:
        mol_df = pd.read_csv(file_path).loc[
            :, [mol_col, tgt_col]]
        mol_list = [Chem.MolFromSmiles(each_smiles) for each_smiles in mol_df.iloc[:, 0]]
        tgt_list = mol_df.iloc[:, 1].tolist()
    else:
        mol_list = []
        tgt_list = []
    # Preprocessing never changes the length of tgt_list, so with no targets
    # there is nothing to transform and the expressions are never evaluated.
    if tgt_list:
        for each_preprocessor in preprocessing_list:
            # SECURITY: eval executes arbitrary code. Only pass trusted,
            # developer-controlled expressions here, never external input.
            preprocess = eval(each_preprocessor)
            tgt_list = [preprocess(each_tgt) for each_tgt in tgt_list]
    return mol_list, tgt_list


def mol2hash(mol):
    """Return ``Chem.MolToInchiKey(mol)``, used as the molecule's hash key."""
    inchi_key = Chem.MolToInchiKey(mol)
    return inchi_key


class DeviceContext(torch.cuda.device):
    """``torch.cuda.device`` context manager that is a no-op without CUDA.

    When ``torch.cuda.is_available()`` is true, construction, entry and exit
    delegate to ``torch.cuda.device``; otherwise they do nothing, so the same
    ``with DeviceContext(idx):`` code runs on CPU-only hosts. Availability is
    checked again in every method. The ``as`` target is always ``None`` and
    exceptions raised inside the ``with`` body propagate.
    """

    def __init__(self, device):
        if not torch.cuda.is_available():
            # CPU-only host: skip the parent initialisation entirely.
            return
        super().__init__(device)

    def __enter__(self):
        cuda_ok = torch.cuda.is_available()
        if cuda_ok:
            super().__enter__()

    def __exit__(self, type, value, traceback):
        if not torch.cuda.is_available():
            return
        super().__exit__(type, value, traceback)

def device_count():
    """Return the number of visible CUDA devices, or 1 (the CPU) without CUDA."""
    return torch.cuda.device_count() if torch.cuda.is_available() else 1

def device_name(device_idx):
    """Map ``device_idx`` to a torch device string.

    Without CUDA this is always ``'cpu'``; otherwise ``'cuda:<k>'`` where
    ``k = device_idx % device_count()``, so indices wrap around the visible GPUs.
    """
    if not torch.cuda.is_available():
        return 'cpu'
    return f'cuda:{device_idx % device_count()}'
